from distutils.util import strtobool
import os
import sys
from typing import List, Optional, Tuple
import argparse
import json

import torch
from sklearn.model_selection import train_test_split
from torch.nn import L1Loss, MSELoss
from torch.utils.data import DataLoader

from Utils import logger, config as cfg, get_logger
from Utils.Constants import FileNamesConstants as Names
from ModelsUtils.ModelTrainer import ModelTrainer
from ModelsUtils.ModelCallbacks import ModelCheckpointCallback, EarlyStoppingCallback
from DataHandling.DataUtils import stats_folders_train_test_split
from DataHandling.AutoEncodeSingleDataset import AutoEncoderStatsDataset, AutoEncoderSingleMapsDataset
from Encoder.FeatureMapAutoEncoder import FeatureMapAutoEncoder


def _load_kfold_split(kfold_folder: str) -> Tuple[list, list, list]:
    """Read the train/val/test JSON file lists stored in ``kfold_folder``; a missing list file yields ``[]``."""
    split = []
    for name in [Names.TRAIN_FILES, Names.VAL_FILES, Names.TEST_FILES]:
        list_path = os.path.join(kfold_folder, name)
        if not os.path.exists(list_path):
            split.append([])
            continue
        with open(list_path, 'r') as f:
            split.append(json.load(f))
    return split[0], split[1], split[2]


def _split_files(files: list, fraction: Optional[float]) -> Tuple[list, list]:
    """Split ``files`` into (kept, held_out) with a seeded split; a 0 or None ``fraction`` holds nothing out."""
    if not fraction:
        # None is treated like 0 (no split). Passing None to train_test_split would silently
        # fall back to sklearn's default 25% hold-out instead.
        return files, []
    kept, held_out = train_test_split(files, test_size=fraction, random_state=cfg.seed)
    return kept, held_out


def setup_data(train_val_folders: List[str], val_size: Optional[float], test_size: Optional[float],
               kfold_folder_base: Optional[str], kfold_num: Optional[int], use_maps_dataset: bool, data_source: str):
    """
    Resolve the train / validation / test file lists for an auto-encoder run.

    If ``kfold_folder_base`` is given, the lists are read from the JSON files stored in
    ``<kfold_folder_base>/<kfold_num>_<data_source>``. Otherwise the files under
    ``train_val_folders`` are split here.

    :param train_val_folders: folders whose files are split (ignored in K-fold mode)
    :param val_size: validation fraction. For the maps dataset, 0 or None means no validation split.
        For the stats dataset it is forwarded to ``stats_folders_train_test_split`` unchanged.
    :param test_size: test fraction (maps dataset only); 0 or None means no test split
    :param kfold_folder_base: base folder of pre-computed K-fold splits, or None
    :param kfold_num: fold index used to build the K-fold folder name
    :param use_maps_dataset: True to split individual map files, False to use the stats-folder split
    :param data_source: data source name; 'weights' selects weights data
    :return: tuple ``(train_files, val_files, test_files)``
    """
    if kfold_folder_base is not None:
        return _load_kfold_split(os.path.join(kfold_folder_base, f'{kfold_num}_{data_source}'))

    if not use_maps_dataset:
        train_files, val_files = stats_folders_train_test_split(data_source == 'weights', train_val_folders, val_size)
        return train_files, val_files, list()

    # os.listdir returns entries in arbitrary order. Sorting makes the seeded split reproducible
    # across machines and filesystems.
    all_files = [os.path.join(curr_folder, curr_file)
                 for curr_folder in train_val_folders for curr_file in sorted(os.listdir(curr_folder))]
    remaining_files, test_files = _split_files(all_files, test_size)
    train_files, val_files = _split_files(remaining_files, val_size)
    return train_files, val_files, test_files


def _create_dataset(files: List[str], is_weights: bool, is_single_layer: bool, use_maps_dataset: bool,
                    pre_load_data: bool):
    """
    Build the auto-encoder dataset for ``files`` with the dataset class matching the data layout.

    :param files: paths of the data files (map files or stats files, depending on ``use_maps_dataset``)
    :param is_weights: True when the data source is 'weights'
    :param is_single_layer: forwarded to the stats dataset only
    :param use_maps_dataset: True for AutoEncoderSingleMapsDataset, False for AutoEncoderStatsDataset
    :param pre_load_data: forwarded to the maps dataset only
    :return: the dataset produced by the selected ``create_dataset`` factory
    """
    if use_maps_dataset:
        return AutoEncoderSingleMapsDataset.create_dataset(files, is_weights=is_weights, pre_load_data=pre_load_data)
    # Keyword arguments everywhere, so no call depends on the factory's positional parameter order.
    return AutoEncoderStatsDataset.create_dataset(files, is_weights=is_weights, is_single_layer=is_single_layer)


def train_auto_encoder(train_val_folders: List[str], test_folders: Optional[List[str]], val_size: Optional[float],
                       test_size: Optional[float], kfold_folder_base: Optional[str], kfold_num: Optional[int],
                       epochs: int, batch_size: int, data_source: str,
                       is_single_layer: bool, model_save_path: str, layer_ops: Tuple[int, ...],
                       use_maps_dataset: bool = True, pre_load_data: bool = True):
    """
    Train a FeatureMapAutoEncoder, save it, and optionally evaluate it on a test set.

    Files written to ``model_save_path``:
    - the run parameters (including the config)
    - the train/val/test file lists
    - checkpoints under ``cp``
    - the final weights as ``last_ae_model.pt``

    Side effect: when ``use_maps_dataset`` is True, the global ``cfg.stats_used_in_feature_map`` is set to 12.

    :param train_val_folders: folders split into train/validation (and possibly test) sets by ``setup_data``
    :param test_folders: optional extra folders whose files are all added to the test set
    :param val_size: validation fraction, see ``setup_data``
    :param test_size: test fraction, see ``setup_data``
    :param kfold_folder_base: when set, split lists are read from ``<kfold_folder_base>/<kfold_num>_<data_source>``
    :param kfold_num: fold index for K-fold mode
    :param epochs: maximum number of epochs (early stopping with patience 25 may stop sooner)
    :param batch_size: batch size for all data loaders
    :param data_source: 'weights' selects weights data; any other value is treated as non-weights data
    :param is_single_layer: forwarded to the stats dataset (only used when ``use_maps_dataset`` is False)
    :param model_save_path: output directory; created if it does not exist
    :param layer_ops: at least three integers, one per auto-encoder stage. They are paired with the shapes
        (1, feature_map_values_size), (feature_map_layers_size, 1) and (1, 1).
    :param use_maps_dataset: True for using pre-made maps
    :param pre_load_data: True for loading all data in advance to memory (better performance)
    :raises ValueError: if ``layer_ops`` does not provide three values
    """
    # Snapshot the call arguments first, while they are still the only locals.
    func_parameters = locals().copy()

    if layer_ops is None or len(layer_ops) < 3:
        raise ValueError(f'layer_ops must hold 3 integers, got {layer_ops!r}')

    # Apply the maps-dataset override before the config is serialised, so the saved run
    # parameters describe the configuration actually used for training.
    if use_maps_dataset:
        cfg.stats_used_in_feature_map = 12

    os.makedirs(model_save_path, exist_ok=True)
    func_parameters['config'] = cfg.to_json()
    with open(os.path.join(model_save_path, Names.RUN_PARAMETERS), 'w') as jf:
        json.dump(func_parameters, jf)

    is_weights = data_source == 'weights'
    logger().log('AutoEncoderTraining::train_auto_encoder', 'Parameters: ', func_parameters, '\n\n\n')

    layers_sizes = [((1, cfg.feature_map_values_size), layer_ops[0]),
                    ((cfg.feature_map_layers_size, 1), layer_ops[1]),
                    ((1, 1), layer_ops[2])
                    ]
    logger().print_and_log('AutoEncoderTraining::train_auto_encoder',
                           f'Training Auto encoder with shapes: {layers_sizes}. Save path: {model_save_path}. '
                           f'Training on: {train_val_folders}, {data_source=}, {is_single_layer=}')
    ae_model = FeatureMapAutoEncoder(layers_sizes, model_save_path)

    train_files, val_files, test_files = setup_data(train_val_folders=train_val_folders, val_size=val_size,
                                                    test_size=test_size, data_source=data_source,
                                                    use_maps_dataset=use_maps_dataset,
                                                    kfold_folder_base=kfold_folder_base, kfold_num=kfold_num)

    # Persist the split so the run can be reproduced or re-evaluated later.
    for name, data in zip([Names.TRAIN_FILES, Names.VAL_FILES, Names.TEST_FILES], [train_files, val_files, test_files]):
        with open(os.path.join(model_save_path, name), 'w') as file:
            json.dump(data, file)

    train_dataset = _create_dataset(train_files, is_weights, is_single_layer, use_maps_dataset, pre_load_data)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    if len(val_files) == 0:
        val_dataloader = None
        logger().log('AutoEncoderTraining::train_auto_encoder', 'Has no validation')
    else:
        val_dataset = _create_dataset(val_files, is_weights, is_single_layer, use_maps_dataset, pre_load_data)
        # Validation only evaluates, so a fixed order is enough (and deterministic).
        val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)

    metrics_funcs = [L1Loss(reduction='sum')]
    metrics_names = ['MAE']
    trainer = ModelTrainer(ae_model, ae_model.tb_writer(), loss_func=MSELoss(reduction='sum'), use_scheduler=True)
    logger().force_log_and_print('AutoEncoderTraining::train_auto_encoder', f'Start training Auto-Encoder for {epochs} epochs')
    # Without a validation set, checkpointing and early stopping monitor the training loss instead.
    trainer.fit(epochs, train_dataloader, val_dataloader, metrics_funcs=metrics_funcs, metrics_names=metrics_names,
                checkpoint_cb=ModelCheckpointCallback(os.path.join(model_save_path, 'cp'), use_loss=True, use_train=val_dataloader is None),
                early_stopping=EarlyStoppingCallback(patience=25, use_train=val_dataloader is None, use_loss=True))
    torch.save(ae_model.state_dict(), os.path.join(model_save_path, 'last_ae_model.pt'))
    logger().force_log_and_print('AutoEncoderTraining::train_auto_encoder', 'Finished training Auto-Encoder')

    # Files from extra test folders are appended to whatever test split setup_data produced.
    if test_folders:
        test_files = list(test_files) + [os.path.join(curr_folder, curr_file)
                                         for curr_folder in test_folders
                                         for curr_file in sorted(os.listdir(curr_folder))]
    if test_files:
        test_dataset = _create_dataset(test_files, is_weights, is_single_layer, use_maps_dataset, pre_load_data)
        test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)
        trainer.evaluate(test_dataloader, metrics_funcs=metrics_funcs, metrics_names=metrics_names)


def parse_args(argv: Optional[List[str]] = None):
    """
    Parse the command line of the auto-encoder training script.

    :param argv: arguments to parse, without the program name; defaults to ``sys.argv[1:]``
    :return: the parsed ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser(description='Auto-Encoder Training')
    parser.add_argument('-f', '--folders', nargs="+", type=str, default=None,
                        help='Folders of stats files')
    # Required: the script cannot resolve a save directory without it.
    parser.add_argument('-s', '--model_save_dir', type=str, required=True,
                        help='path to all model specific run data')
    parser.add_argument('--test_folders', nargs='+', type=str, default=None)
    parser.add_argument('-data', '--data_source', type=str,
                        help='Data source from weights or gradients')
    parser.add_argument('-vs', '--val_size', type=float, default=0.2)
    parser.add_argument('-ts', '--test_size', type=float, default=0)
    parser.add_argument('-isl', '--is_single_layer', type=str, default='True',
                        help='True for single layer stats, False for multi layer stats')
    parser.add_argument('-bs', '--batch_size', type=int, default=32)
    parser.add_argument('-e', '--epochs', type=int, default=150)
    # Required: the three values are indexed unconditionally when building the model.
    parser.add_argument('-ae', '--auto_encoder_struct', type=int, nargs=3, required=True,
                        help='Auto Encoder Shape to use (three integers, one per encoder stage)')

    parser.add_argument('--kfolder', type=str, default=None, help='K fold data folder if K fold is wanted')
    parser.add_argument('--knum', type=int, default=-1, help='number of fold used for K fold')
    return parser.parse_args(sys.argv[1:] if argv is None else argv)


if __name__ == '__main__':
    # Name the logger after this script file, without its extension.
    get_logger(os.path.basename(__file__).split('.')[0])
    cli_args = parse_args()

    # Relative save directories are resolved under the shared models-training root.
    save_dir = cli_args.model_save_dir
    if not os.path.isabs(save_dir):
        save_dir = os.path.join('/sise/group/models_training', save_dir)

    train_auto_encoder(
        train_val_folders=cli_args.folders,
        test_folders=cli_args.test_folders,
        val_size=cli_args.val_size,
        test_size=cli_args.test_size,
        epochs=cli_args.epochs,
        batch_size=cli_args.batch_size,
        data_source=cli_args.data_source,
        is_single_layer=bool(strtobool(cli_args.is_single_layer)),
        model_save_path=save_dir,
        layer_ops=cli_args.auto_encoder_struct,
        kfold_folder_base=cli_args.kfolder,
        kfold_num=cli_args.knum,
    )
